-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Improve mxnet support for activity classifier save/load #129
Conversation
aa1274b
to
5e09c7d
Compare
5e09c7d
to
ae3e948
Compare
context = _mxnet_utils.get_mxnet_context(max_devices=state['num_sessions']) | ||
state['_loss_model'] = _mxnet_utils.load_mxnet_model_from_state( | ||
state['_loss_model'], data, labels, None, context) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this mean we're no longer backward compatible?
If someone saved a model using version 1, the weights are now saved only in the loss model, and therefore when later in lines 301-303 when loading params from state['_pred_model']
they would be all zeros, won't they?
I can understand not being forward compatible (model saved in new version should not load in old version). But backwards compatibility is important.
We could check for if version==1 or '_loss_model' in state
then extract the params from loss model, else extract from pred model.
Right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review! In the current v4, there is no weight sharing when it gets saved to file. All weights are saved twice. Looking at the actual saved files, a model saved with v4 takes 4 MB while a model saved with v4+ takes 2 MB. Therefore, there is no problem for v4+ to simply ignore half of those weights and load the model entirely from the pred_model
.
Also, regarding backward compatibility. Every cell in the 6x6 matrix I showed in the original post is the result of an actual test and not just my hopes (I wanted to be very thorough!), so I have tested and verified full backward compatibility.
Please see my comment about backward compatibility. Otherwise LGTM. |
Sounds good, I will wait for @alonpal's review. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got a message from @alonpal. He's taking a flight so he can't get online, but he reviewed the changes and approves them.
This PR updates the activity classifier (AC) save/load to work with more versions of mxnet. There are still some issues with export to Core ML for newer versions of mxnet, so this PR alone does not expand support yet (#17).
The issue with the AC is that it saves and loads the network graph using mxnet json files. MXNet is pretty good at backward compatibility, but not forward compatibility. If we support more than one version of mxnet at a time, this creates a problem for us where we can't even be same-version compatible:
This won't work due to forward incompatibility in mxnet. The solution is to avoid saving and loading the graph and simply building it up and copying over the weights. This has much better forward compatibility.
Support (old and new)
Let's look at current support and support after this PR. I'll show my testing as matrices where:
Classifier (IC)/Similarity (IS)/Detector (OD): (all work on all combinations of save/load, although currently you get warnings if you do not use MXNet 0.11.0. This PR will eliminate those warnings. This broad support is thanks to the same changes that I'm making to the AC in this PR)
Activity Classifier (AC): (since this PR changes the AC saver/loader, I tested a bunch of combinations)
✅ Works
⚠️ Does not work (fails gracefully)
🚫 Does not work (ugly error)
"PR" refers to the Turi Create model as defined by this PR's commit. "1post1" is short for 1.0.0.post1 (1.0.0 segfaults the object detector, and this seems to have been resolved in the post1 version).
Top-left: This is the status quo
Right half: Backwards and same-version compatible
Bottom-left: Forward incompatible (with respect to TC version). See note at the bottom.
The "graceful" failure in 4.0 actually says "Corrupted model. Cannot load a model with this version." for OD/AC, and for IC/IS it does not even check the version! This PR in an isolated commit also improves this and makes the message friendlier and tells the user to upgrade Turi Create. Unfortunately, whenever we upgrade the file format for IC/IS, it will fail very ungracefully on 4.0.
Why not be forward compatible?
We could make newer models load in 4.0 as well. However, that is a commitment to write the backward migration to mxnet for all its future versions. For instance, in mxnet 5.0, we would still need to write json graphs that look like 0.11.0. It's better to break this compatibility now, since we would probably break it eventually. At least going forward, we have much better chances of being forward compatible (in mxnet version) for the AC, just like it turned out we are for IC/IS/OD.